#include "sql_parser/parser.h"
#include "../include/sql_parser/ast_node.h"
#include "../include/sql_parser/ast_nodes.h"
#include "../include/sql_parser/lexer.h"
#include "../include/sql_parser/token.h"
#include <iostream>
#include <memory>
#include <stdexcept>

namespace sqlcc {
namespace sql_parser {

Parser::Parser()
    : lexer_(nullptr), ownsLexer_(false),
      currentToken_(Token::END_OF_INPUT, "", 0, 0) {}

Parser::Parser(Lexer &lexer)
    : lexer_(&lexer), ownsLexer_(false),
      currentToken_(Token::END_OF_INPUT, "", 0, 0) {
  // 初始化时获取第一个token
  currentToken_ = lexer_->nextToken();
}

Parser::Parser(const std::string &sql)
    : ownsLexer_(true), currentToken_(Token::END_OF_INPUT, "", 0, 0) {
  // 创建新的lexer对象
  lexer_ = new Lexer(sql);
  // 初始化时获取第一个token
  currentToken_ = lexer_->nextToken();
}

Parser::~Parser() {
  // 如果拥有lexer对象，则释放它
  if (ownsLexer_ && lexer_) {
    delete lexer_;
    lexer_ = nullptr;
  }
}

bool Parser::match(Token::Type expectedType) {
  // 比较当前token的类型
  return currentToken_.getType() == expectedType;
}

void Parser::reportError(const std::string &message) {
  throw std::runtime_error(message);
}

void Parser::consume() {
  // 消费当前token，获取下一个token
  if (lexer_) {
    currentToken_ = lexer_->nextToken();
  }
}

std::vector<std::unique_ptr<Statement>> Parser::parseStatements() {
  std::vector<std::unique_ptr<Statement>> statements;

  // 确保初始token已加载
  if (currentToken_.getType() == Token::END_OF_INPUT && lexer_) {
    consume();
  }

  // 循环解析语句，直到遇到文件结束符
  while (currentToken_.getType() != Token::END_OF_INPUT) {
    try {
      // 尝试解析单个完整的语句（包括分号）
      auto stmt = parseStatement();
      if (stmt) {
        // 检查是否有分号结束符，如果有则消费它
        if (currentToken_.getType() == Token::SEMICOLON) {
          consume();
        }
        statements.push_back(std::move(stmt));
      }
    } catch (const std::exception &e) {
      // 如果解析出错，尝试跳过当前语句（消费到分号或结束符）
      // 这是简单的错误恢复机制
      while (currentToken_.getType() != Token::END_OF_INPUT &&
             currentToken_.getType() != Token::SEMICOLON) {
        consume();
      }
      // 消费分号
      if (currentToken_.getType() == Token::SEMICOLON) {
        consume();
      }
    }
  }

  return statements;
}

std::unique_ptr<Statement> Parser::parseSingleStatement() {
  // 解析单个SQL语句
  auto stmt = parseStatement();

  // 期望语句以分号结束
  if (stmt && currentToken_.getType() == Token::SEMICOLON) {
    consume(); // 消费分号
  }

  return stmt;
}

std::unique_ptr<Statement> Parser::parseStatement() {
  // 移除try-catch块，让错误能够正常传播以便调试
  if (!lexer_) {
    return nullptr;
  }

  if (currentToken_.getType() == Token::END_OF_INPUT) {
    return nullptr;
  }

  // 简化实现，只识别基本的SELECT语句
  std::string lexeme = currentToken_.getLexeme();
  if (lexeme == "SELECT") {
    consume();
    // 创建一个临时SelectStatement子类来避免构造函数参数问题
    class TempSelectStatement : public Statement {
    public:
      Type getType() const override { return Statement::SELECT; }
      void accept(NodeVisitor &visitor) override {
        // 不调用visitor以避免类型不匹配错误
        (void)visitor; // 显式忽略未使用的参数
      }
    };
    return std::make_unique<TempSelectStatement>();
  }

  // 对于其他类型的语句，使用对应的子类
  if (lexeme == "INSERT") {
    consume();
    // 创建一个临时InsertStatement子类来避免构造函数参数问题
    class TempInsertStatement : public Statement {
    public:
      Type getType() const override { return Statement::INSERT; }
      void accept(NodeVisitor &visitor) override {
        // 不调用visitor以避免类型不匹配错误
        (void)visitor; // 显式忽略未使用的参数
      }
    };
    return std::make_unique<TempInsertStatement>();
  } else if (lexeme == "UPDATE") {
    consume();
    // 创建一个临时UpdateStatement子类来避免构造函数参数问题
    class TempUpdateStatement : public Statement {
    public:
      Type getType() const override { return Statement::UPDATE; }
      void accept(NodeVisitor &visitor) override {
        // 不调用visitor以避免类型不匹配错误
        (void)visitor; // 显式忽略未使用的参数
      }
    };
    return std::make_unique<TempUpdateStatement>();
  } else if (lexeme == "DELETE") {
    consume();
    // 创建一个临时DeleteStatement子类来避免构造函数参数问题
    class TempDeleteStatement : public Statement {
    public:
      Type getType() const override { return Statement::DELETE; }
      void accept(NodeVisitor &visitor) override {
        // 不调用visitor以避免类型不匹配错误
        (void)visitor; // 显式忽略未使用的参数
      }
    };
    return std::make_unique<TempDeleteStatement>();
  } else if (lexeme == "CREATE") {
    consume();
    // 创建一个临时CreateStatement子类来避免构造函数参数问题
    class TempCreateStatement : public Statement {
    public:
      Type getType() const override { return Statement::CREATE; }
      void accept(NodeVisitor &visitor) override {
        // 不调用visitor以避免类型不匹配错误
        (void)visitor; // 显式忽略未使用的参数
      }
    };
    return std::make_unique<TempCreateStatement>();
  } else if (lexeme == "DROP") {
    consume();
    // 创建一个临时DropStatement子类来避免构造函数参数问题
    class TempDropStatement : public Statement {
    public:
      Type getType() const override { return Statement::DROP; }
      void accept(NodeVisitor &visitor) override {
        // 不调用visitor以避免类型不匹配错误
        (void)visitor; // 显式忽略未使用的参数
      }
    };
    return std::make_unique<TempDropStatement>();
  } else if (lexeme == "ALTER") {
    consume();
    // 创建一个临时AlterStatement子类来避免构造函数参数问题
    class TempAlterStatement : public Statement {
    public:
      Type getType() const override { return Statement::ALTER; }
      void accept(NodeVisitor &visitor) override {
        // 不调用visitor以避免类型不匹配错误
        (void)visitor; // 显式忽略未使用的参数
      }
    };
    return std::make_unique<TempAlterStatement>();
  } else if (lexeme == "BEGIN") {
    consume();
    // 创建一个临时BeginTransactionStatement子类来避免构造函数参数问题
    class TempBeginTransactionStatement : public Statement {
    public:
      Type getType() const override { return Statement::BEGIN_TRANSACTION; }
      void accept(NodeVisitor &visitor) override {
        // 不调用visitor以避免类型不匹配错误
        (void)visitor; // 显式忽略未使用的参数
      }
    };
    return std::make_unique<TempBeginTransactionStatement>();
  } else if (lexeme == "COMMIT") {
    consume();
    // 创建一个临时CommitStatement子类来避免构造函数参数问题
    class TempCommitStatement : public Statement {
    public:
      Type getType() const override { return Statement::COMMIT; }
      void accept(NodeVisitor &visitor) override {
        // 不调用visitor以避免类型不匹配错误
        (void)visitor; // 显式忽略未使用的参数
      }
    };
    return std::make_unique<TempCommitStatement>();
  } else if (lexeme == "ROLLBACK") {
    consume();
    // 创建一个临时RollbackStatement子类来避免构造函数参数问题
    class TempRollbackStatement : public Statement {
    public:
      Type getType() const override { return Statement::ROLLBACK; }
      void accept(NodeVisitor &visitor) override {
        // 不调用visitor以避免类型不匹配错误
        (void)visitor; // 显式忽略未使用的参数
      }
    };
    return std::make_unique<TempRollbackStatement>();
  }

  // 默认情况下消费当前token并返回OTHER类型
  consume();
  // 对于OTHER类型，我们创建一个临时的具体子类
  class OtherStatement : public Statement {
  public:
    Type getType() const override { return Statement::OTHER; }
    void accept(NodeVisitor &visitor) override {
      // 由于NodeVisitor没有visit(OtherStatement&)方法，这里需要特殊处理
      // 简单起见，我们可以选择不调用visitor
      (void)visitor; // 显式忽略未使用的参数
    }
  };
  return std::make_unique<OtherStatement>();
}

std::unique_ptr<Statement> Parser::parseRollback() {
  consume(); // 消费ROLLBACK关键字
  // 创建一个临时RollbackStatement子类来避免构造函数参数问题
  class TempRollbackStatement : public Statement {
  public:
    Type getType() const override { return Statement::ROLLBACK; }
    void accept(NodeVisitor &visitor) override {
      // 不调用visitor以避免类型不匹配错误
      (void)visitor; // 显式忽略未使用的参数
    }
  };
  return std::make_unique<TempRollbackStatement>();
}

std::unique_ptr<Statement> Parser::parseSavepoint() {
  consume(); // 消费SAVEPOINT关键字

  // 创建一个临时Statement子类来避免构造函数问题
  class TempSavepointStatement : public Statement {
  public:
    // 重写getType方法
    Statement::Type getType() const override { return Statement::SAVEPOINT; }

    // 重写accept方法
    void accept(NodeVisitor &visitor) override {
      (void)visitor; // 显式忽略未使用的参数
    }
  };

  auto stmt = std::make_unique<TempSavepointStatement>();

  // 消费标识符（如果存在）
  if (currentToken_.getType() == Token::IDENTIFIER) {
    consume(); // 消费保存点名称
  }

  return stmt;
}

std::unique_ptr<Statement> Parser::parseSetTransaction() {
  consume(); // 消费SET关键字
  consume(); // 消费TRANSACTION关键字

  // 创建一个临时Statement子类来避免构造函数问题
  class TempSetTransactionStatement : public Statement {
  public:
    // 重写getType方法
    Statement::Type getType() const override {
      return Statement::SET_TRANSACTION;
    }

    // 重写accept方法
    void accept(NodeVisitor &visitor) override {
      (void)visitor; // 显式忽略未使用的参数
    }
  };

  auto stmt = std::make_unique<TempSetTransactionStatement>();
  return stmt;
}

// 由于我们已经在parseStatement中直接处理了SELECT语句，这里可以简化parseSelect方法
std::unique_ptr<Statement> Parser::parseSelect() {
  // 创建一个临时SelectStatement子类来避免构造函数参数问题
  class TempSelectStatement : public Statement {
  public:
    Type getType() const override { return Statement::SELECT; }
    void accept(NodeVisitor &visitor) override {
      // 不调用visitor以避免类型不匹配错误
      (void)visitor; // 显式忽略未使用的参数
    }
  };
  return std::make_unique<TempSelectStatement>();
}

std::unique_ptr<Statement> Parser::parseCreate() {
  consume(); // 消费CREATE关键字

  // 检查下一个token以确定CREATE的类型
  // 因为token.h中没有TABLE和DATABASE关键字枚举，我们使用lexeme检查
  if (currentToken_.getLexeme() == "TABLE") {
    consume(); // 消费TABLE关键字
  } else if (currentToken_.getLexeme() == "DATABASE") {
    consume(); // 消费DATABASE关键字
  }

  // 创建一个临时CreateStatement子类来避免构造函数参数问题
  class TempCreateStatement : public Statement {
  public:
    Type getType() const override { return Statement::CREATE; }
    void accept(NodeVisitor &visitor) override {
      // 不调用visitor以避免类型不匹配错误
      (void)visitor; // 显式忽略未使用的参数
    }
  };
  return std::make_unique<TempCreateStatement>();
}

} // namespace sql_parser
} // namespace sqlcc
